from django.http import JsonResponse, HttpResponse

from djmockserver.mock.serializer import MockItem, RequestSerializer
import djmockserver.mock.extend_func.extend_functions as extfunc
from djmockserver.mock.extend_func.extend_functions import *
from djmockserverproject import settings

import ast
import json
import logging
import copy
import requests


# Module-level logger, named after this module so its records can be filtered by module path.
logger = logging.getLogger(__name__)


class checkVisitor(ast.NodeVisitor):
    """
    AST visitor that vets an expression before it is handed to ``eval``.

    The expression is parsed and its syntax tree walked; it is reported
    unsafe when it references a name that is neither a known variable
    (``safe_vars``) nor a whitelisted function (``_safe_func``), or when it
    contains a lambda.

    NOTE(review): only ``Name`` and ``Lambda`` nodes are inspected. Attribute
    access on an allowed name (for example ``req.__class__``) is not
    restricted, so this check is not a complete sandbox. Expressions should
    only come from trusted mock definitions.
    """
    # Extra names allowed by the project settings (the original comment says
    # these are Python built-in functions). Copied into a list so that a tuple
    # in settings can still be concatenated below.
    _add_func = list(getattr(settings, 'ADD_FUNC', []))
    # Whitelist: names exported by the extension-function module plus extras.
    _safe_func = list(extfunc.__all__) + _add_func

    def __init__(self):
        # Verdict of the most recent ``safe()`` call.
        self._safe = True
        # Variable names (mapping keys) that expressions may reference.
        self.safe_vars = {}
        super().__init__()

    def safe(self, expression):
        """
        Return True when ``expression`` only uses whitelisted names.

        :param expression: value to check; it is converted with ``str()``
            before parsing
        :return: False when the text cannot be parsed as Python or when it
            uses a non-whitelisted name or a lambda, True otherwise
        """
        # Start every check from a clean verdict. Without this reset a single
        # rejected expression would make all later checks on the same visitor
        # fail as well.
        self._safe = True
        try:
            node = ast.parse(str(expression))
        except Exception:
            # Text that is not valid Python cannot be evaluated anyway.
            return False
        self.visit(node)
        return self._safe

    def visit_Name(self, node):
        "Reject names that are neither known variables nor whitelisted functions."
        if node.id not in self.safe_vars and node.id not in self._safe_func:
            logger.warning('unsafe var or function expression : {}'.format(node.id))
            self._safe = False
            return
        self.generic_visit(node)

    def visit_Lambda(self, node):
        "Reject lambda expressions; the lambda body is not visited."
        self._safe = False


class LocalEnv:
    """
    Variable environment built from the ``vars`` section of a mock entry.

    Each value is checked with the visitor and, when reported safe, evaluated
    with ``eval``. Values that are unsafe or fail to evaluate are stored
    unchanged. Variables are processed in dictionary order and an expression
    can use the evaluated result of any variable processed before it.
    """

    def __init__(self, env_dict: dict, visitobj: checkVisitor):
        """
        :param env_dict: variables read from the mock file (name -> value or
            expression text)
        :param visitobj: checker whose ``safe()`` decides whether a value may
            be evaluated
        """
        # NOTE(review): the raw values are also published as module globals,
        # which lets later ``eval`` calls in this module resolve them. A
        # variable named like a module-level object (e.g. ``json``) would
        # overwrite it, and values persist across requests - confirm whether
        # this is still required.
        globals().update(env_dict)
        self._vars = {}

        for var_name, var_value in env_dict.items():
            if not visitobj.safe(var_value):
                logger.warning('expression <{}> is unsafe, return raw data.'.format(var_value))
                self._vars[var_name] = var_value
                continue
            try:
                # self._vars is passed as the local namespace so the
                # expression sees variables evaluated earlier. Writing to
                # locals() instead is not reliable: its result is a snapshot
                # whose modification is not guaranteed to be visible.
                self._vars[var_name] = eval(var_value, globals(), self._vars)
            except Exception as e:
                # Includes non-string values, which eval() rejects.
                logger.warning('var <{}> parse error : {}'.format(var_value, e))
                self._vars[var_name] = var_value

    def set(self, var_name, var_value):
        """Store ``var_value`` under ``var_name``, replacing any existing value."""
        self._vars[var_name] = var_value

    def get(self, var_name, default=None):
        """Return the value stored under ``var_name``, or ``default`` when absent."""
        # dict.get() only accepts its default positionally.
        return self._vars.get(var_name, default)

    @property
    def vars(self):
        """The underlying name -> value dictionary (not a copy)."""
        return self._vars


def eval_resp(resp_json, env_vars: dict, visitobj: checkVisitor):
    """
    Evaluate the variable and function expressions inside a mock response.

    Dicts and lists are walked recursively and updated in place. Numbers and
    ``None`` are kept as they are. A string is evaluated with ``eval`` when
    ``visitobj.safe`` accepts it; otherwise, or when evaluation fails, the
    original string is kept. A numeric evaluation result is converted back
    to ``str``.

    :param resp_json: response payload (dict, list, str, int, float or None)
    :param env_vars: variables visible to the evaluated expressions
    :param visitobj: checker deciding which strings may be evaluated
    :return: the processed payload; if an unsupported type is encountered the
        error is logged and ``resp_json`` is returned as processed so far
    """
    # Explicit namespace for eval(). Injecting names through locals() is not
    # reliable because modifications of that mapping are not guaranteed to be
    # visible to the interpreter.
    namespace = dict(env_vars or {})

    def run(node):
        if isinstance(node, dict):
            for key, value in node.items():
                node[key] = run(value)
            return node
        if isinstance(node, list):
            for index, value in enumerate(node):
                node[index] = run(value)
            return node
        # JSON null is a legitimate value; pass it through instead of
        # aborting the evaluation of the remaining payload.
        if node is None or isinstance(node, (float, int)):
            return node
        if isinstance(node, str):
            try:
                if not visitobj.safe(node):
                    return node
                result = eval(node, globals(), namespace)
            except Exception as e:
                logger.warning('expression <{}> eval failed, return raw data. exception detail : {}'.format(node, e))
                return node
            # The source value was a string, so numeric results stay strings.
            return str(result) if isinstance(result, (float, int)) else result
        raise Exception('data type is {}, but require dict/list/int/float/str'.format(type(node)))

    try:
        return run(resp_json)
    except Exception as e:
        logger.error('resp json parse error, detail : {}'.format(e))
        return resp_json


def forward_request(req):
    """
    Forward an unmatched request to ``settings.REMOTE_HOST`` and relay the reply.

    The request's method, headers (without ``Host``), cookies and body are
    sent to the remote host under the same path. The reply's body, status
    code and ``Content-Type`` are copied into the returned response.

    :param req: serialized request; ``_path``, ``method``, ``_headers``,
        ``cookies`` and ``_body`` are read from it
    :return: ``HttpResponse`` mirroring the remote reply, or an empty response
        with status 500 when the forwarding fails
    """
    logger.info('{0} the request is going to forward {0}'.format('*'*10))
    logger.info('forward the request : {}'.format(req))

    path = req._path
    logger.info('request full path : {}'.format(path))

    remote_host = settings.REMOTE_HOST
    url = '{}{}'.format(remote_host, path)
    logger.info('url that to be forwarding : {}'.format(url))

    method = req.method
    logger.info('request method : {}'.format(method))

    # Send a copy of the headers without ``Host`` so the header of the mock
    # server is not passed on to the remote host. Copying leaves the request
    # object untouched, and pop() tolerates a request without that header.
    headers = dict(req._headers)
    headers.pop('Host', None)
    logger.info('request headers : {}'.format(headers))

    cookies = req.cookies
    logger.info('request COOKIES : {}'.format(cookies))

    body = req._body
    logger.info('request body : {}'.format(body))

    try:
        resp = requests.request(method=method, url=url, headers=headers, cookies=cookies, data=body, timeout=60)

        logger.info('{0} response {0}'.format('*'*10))
        logger.info('code : {}'.format(resp.status_code))
        logger.info('content : {}'.format(resp.text))
        logger.info('response headers : {}'.format(resp.headers))

        response = HttpResponse()

        response.content = resp.content
        response.status_code = resp.status_code
        response['Content-Type'] = resp.headers.get('Content-Type', 'text/plain')

        return response
    except Exception as e:
        # Best effort: any failure while forwarding is reported as a 500.
        logger.warning('forward error : {}'.format(e))
        return HttpResponse(status=500)


class ResponseMaker:
    """
    Build the HTTP response for a matched mock entry.

    Variables and function expressions found in the mock response are
    evaluated before the response is returned.
    """
    def __init__(self, request: RequestSerializer, mockitem: MockItem):
        self._request = request
        # The mock item is cached by the loader and shared between requests
        # until mock.json changes or the service restarts, so all work is
        # done on a private deep copy.
        self._mockitem = copy.deepcopy(mockitem)
        # Data actually sent with the request.
        self._req = self._request.req
        self._uri = self._mockitem.uri

        matched = self._mockitem.resp
        self._resp = matched
        self._status_code = matched.get('status') or 200
        self._headers = matched.get('headers') or {}
        self._text = matched.get('text') or ''
        self._json = matched.get('json') or {}

        mock_vars = matched.get('vars') or {}
        # Expose the request data and the full request path as variables.
        mock_vars['req'] = self._req
        mock_vars['path'] = self._request._path
        self._vars = mock_vars

    @property
    def response(self):
        """
        Return the response for the current request.

        With no matched data and a configured ``REMOTE_HOST`` the request is
        forwarded. With no matched data and no remote host a 404 JSON
        response is returned. Otherwise the mock ``text`` and/or ``json`` is
        evaluated and returned; ``json`` takes precedence when both are set.
        """
        logger.info('matched data : {}'.format(self._resp))

        # Nothing matched and a forwarding target is configured: forward.
        if not self._resp and getattr(settings, 'REMOTE_HOST', None):
            return forward_request(req=self._request)

        resp = HttpResponse()
        resp.status_code = self._status_code
        for header_name, header_value in self._headers.items():
            resp[header_name] = header_value

        # The raw mock variables are the names expressions may reference
        # while the variables themselves are being evaluated.
        visitobj = checkVisitor()
        visitobj.safe_vars = self._vars

        # Evaluate the mock variables.
        localenv = LocalEnv(self._vars, visitobj)

        # Add the global variables from settings.
        for key, value in getattr(settings, 'GLOBAL_VARS', {}).items():
            localenv.set(key, value)

        logger.info('parsed mock vars : {}'.format(localenv.vars))

        # From here on the evaluated variables are the allowed names.
        visitobj.safe_vars = localenv.vars

        if not self._resp:
            return JsonResponse({'message': 'uri does not match'}, status=404)

        if self._text:
            resp.content = eval_resp(self._text, localenv.vars, visitobj)
        if self._json:
            resp.content = json.dumps(eval_resp(self._json, localenv.vars, visitobj))
            resp["Content-Type"] = "application/json; charset=utf-8"
        return resp

